import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.optimize import minimize
from scipy.integrate import cumulative_trapezoid
from scipy.interpolate import interp1d
import re
import os

# --- CONSTANTS ---
# Speed of light in km/s; the velocity models in this file are scaled by it
# and velocities are divided by it to form dimensionless ratios.
C_KM_S = 299792.458
# Fixed offset whose (1 + Z_STAR_MASS) factor is divided out of the observed
# (1 + z) when z_clean is computed in the data-loading section.
# NOTE(review): the physical meaning of this value is not stated here — confirm.
Z_STAR_MASS = 0.0003

# --- 1. LOAD DATA & PARAMS ---
print("Loading data and parameters...")

# Load Data (Raw + Refined)
raw_path = Path('data') / 'Pantheon+SH0ES.dat'
refined_path = Path('produced/refined_actual_data_1701.csv')

# Abort with exit status 1 when an input file is missing. SystemExit is raised
# directly because the bare `exit` helper is injected by the `site` module and
# is not guaranteed to exist (e.g. `python -S` or frozen executables).
if not raw_path.exists():
    print(f"Error: {raw_path} not found.")
    raise SystemExit(1)
if not refined_path.exists():
    print(f"Error: {refined_path} not found. Run data generation first.")
    raise SystemExit(1)

# Read Raw for Forward Fit (Observed magnitudes)
df_raw = pd.read_csv(raw_path, sep=r'\s+', comment='#')
z_obs_raw = df_raw['zHD'].to_numpy()
# Divide the fixed (1 + Z_STAR_MASS) factor out of the observed (1 + z).
z_clean = (1.0 + z_obs_raw) / (1.0 + Z_STAR_MASS) - 1.0
mu_obs = df_raw['MU_SH0ES'].to_numpy()

if 'MU_SH0ES_ERR_DIAG' in df_raw.columns:
    mu_err = df_raw['MU_SH0ES_ERR_DIAG'].to_numpy()
else:
    # No per-point uncertainty column: fall back to unit errors, which makes
    # the later chi-square an unweighted sum of squares. Say so explicitly.
    print("Warning: MU_SH0ES_ERR_DIAG column not found; using unit errors.")
    mu_err = np.ones_like(mu_obs)

# Keep only points above the minimum redshift used for fitting.
mask = z_clean > 0.001
z_fit = z_clean[mask]
mu_fit = mu_obs[mask]
mu_err_fit = mu_err[mask]

# Read Inverse Parameters from produced file
param_path = Path('produced/appendix_B_fit_parameters.txt')
K_EXP_INV = 4224.0  # Default fallback
K_HILL_INV = 3232.0
ETA_HILL_INV = 1.12

# One well-formed unsigned decimal number with an optional exponent. A plain
# character class such as [\d.]+ would truncate "1.5e3" to 1.5 and would also
# capture malformed tokens like "4224.0." that make float() raise.
_NUM = r"((?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?)"

if param_path.exists():
    # errors='replace' keeps a stray non-UTF-8 byte in the report from
    # aborting the run; the numeric fields being parsed are plain ASCII.
    with open(param_path, 'r', encoding='utf-8', errors='replace') as f:
        content = f.read()

    # re.S lets ".*?" span lines between the model heading and its values.
    m_exp = re.search(rf"Exponential.*?K = {_NUM}", content, re.S)
    if m_exp:
        K_EXP_INV = float(m_exp.group(1))

    m_hill = re.search(rf"Sigmoidal.*?K = {_NUM}.*?eta = {_NUM}", content, re.S)
    if m_hill:
        K_HILL_INV = float(m_hill.group(1))
        ETA_HILL_INV = float(m_hill.group(2))

print(f"Inverse Parameters Loaded: Exp K={K_EXP_INV:.0f}, Hill K={K_HILL_INV:.0f}")

# --- 2. PHYSICS MODELS ---

def get_mu_vectorized(z_vals, v_func, args, r_max=20000, n_points=2000):
    """Return model magnitudes ``5*log10(dL) + 25`` at the redshifts ``z_vals``.

    The model is tabulated on a uniform radial grid and then interpolated in
    redshift:

    1. ``beta(r) = v_func(r, *args) / C_KM_S``
    2. ``n(r) = 1 + beta**2`` is integrated cumulatively (trapezoid rule) over
       ``r`` to give ``d_opt(r)``.
    3. ``z(r) = beta / (1 - beta)`` maps each grid radius to a redshift.
    4. ``d_opt`` is linearly interpolated (and extrapolated outside the
       tabulated range) at ``z_vals`` and multiplied by ``(1 + z)`` to get
       ``dL``.

    Args:
        z_vals: Redshifts at which to evaluate the model (NumPy array).
        v_func: Velocity profile called as ``v_func(r_grid, *args)``; must
            return km/s values on the grid.
        args: Sequence of extra positional parameters forwarded to ``v_func``.
        r_max: Upper end of the radial grid. Defaults to the previously
            hard-coded 20000.
        n_points: Number of grid samples. Defaults to the previously
            hard-coded 2000.

    Returns:
        Array of ``5*log10(max(dL, 1e-9)) + 25`` with the shape of ``z_vals``.
    """
    r_grid = np.linspace(0, r_max, n_points)
    v_grid = v_func(r_grid, *args)
    beta_grid = v_grid / C_KM_S

    n_grid = 1.0 + beta_grid**2
    d_opt_grid = cumulative_trapezoid(n_grid, r_grid, initial=0.0)

    z_grid = beta_grid / (1.0 - beta_grid)
    d_opt_interp = interp1d(z_grid, d_opt_grid, kind='linear', fill_value='extrapolate')

    dL_vals = d_opt_interp(z_vals) * (1.0 + z_vals)
    # Floor dL so that log10 never receives zero or a negative value.
    return 5.0 * np.log10(np.maximum(dL_vals, 1e-9)) + 25.0

def v_exp(r, K):
    """Exponential velocity profile: ``C_KM_S * (1 - exp(-r / K))``.

    Evaluates to 0 at ``r = 0`` and approaches ``C_KM_S`` as ``r / K`` grows.
    """
    decay = np.exp(-r / K)
    saturation = 1.0 - decay
    return C_KM_S * saturation

def v_hill(r, K, eta):
    """Sigmoidal (Hill) velocity profile: ``C_KM_S * x / (1 + x)``, ``x = (r / K)**eta``."""
    scaled = (r / K) ** eta
    fraction = scaled / (1.0 + scaled)
    return C_KM_S * fraction

# --- 3. FORWARD FITTING ---

def fit_forward(model_type='exp'):
    """Fit a velocity model to the observed magnitudes by chi-square minimisation.

    The loss is ``sum(((mu_fit - pred) / mu_err_fit)**2)`` with ``pred`` taken
    from ``get_mu_vectorized`` on the module-level ``z_fit`` sample, minimised
    with Nelder-Mead.

    Args:
        model_type: ``'exp'`` fits ``v_exp`` with the single parameter ``K``
            (start 4400). Any other value fits ``v_hill`` with ``(K, eta)``
            (start 3300, 1.1).

    Returns:
        The optimiser's best parameter vector (``res.x``): ``[K]`` for the
        exponential model, ``[K, eta]`` otherwise.
    """
    print(f"Fitting {model_type} forward...")

    # Resolve the model once instead of re-branching on every loss evaluation.
    is_exp = model_type == 'exp'
    v_func = v_exp if is_exp else v_hill
    x0 = [4400.0] if is_exp else [3300.0, 1.1]
    penalty = 1e9

    def loss(params):
        # Nelder-Mead is unconstrained, so out-of-range parameters are
        # rejected with a flat penalty rather than evaluated.
        if params[0] < 1000:
            return penalty
        if not is_exp and params[1] < 0.1:
            return penalty
        pred = get_mu_vectorized(z_fit, v_func, list(params))
        chi2 = np.sum(((mu_fit - pred) / mu_err_fit) ** 2)
        # A NaN/inf objective gives the simplex nothing to compare against;
        # treat it like any other rejected point.
        return chi2 if np.isfinite(chi2) else penalty

    res = minimize(loss, x0, method='Nelder-Mead')
    if not res.success:
        # Surface optimiser failures instead of silently returning res.x.
        print(f"Warning: {model_type} forward fit did not converge: {res.message}")
    return res.x

# Run both forward fits (exponential first, then Hill), then unpack the
# best-fit parameters into the names used by the plotting section.
params_exp_fwd = fit_forward('exp')
params_hill_fwd = fit_forward('hill')

K_EXP_FWD = params_exp_fwd[0]
K_HILL_FWD, ETA_HILL_FWD = params_hill_fwd

print(f"Forward Fit: Exp K={K_EXP_FWD:.2f}, Hill K={K_HILL_FWD:.2f}")

# --- 4. PLOTTING ---
# Two stacked panels comparing the inverse-fit and forward-fit curves of each
# model against the rectified (r, v0) data cloud.
print("Generating Robustness Plot...")
out_path = Path('plots') / 'robustness_comparison_final.png'
out_path.parent.mkdir(parents=True, exist_ok=True)

# Load Rectified Data for plotting cloud
df_rect = pd.read_csv(refined_path)
r_cloud = df_rect['r_true_Mpc'].values
v_cloud = df_rect['v0_derived_km_s'].values

# Lines: evaluate every model on a grid extending 10% past the farthest point.
r_line = np.linspace(0, r_cloud.max()*1.1, 1000)
v_exp_inv_line = v_exp(r_line, K_EXP_INV)
v_exp_fwd_line = v_exp(r_line, K_EXP_FWD)
v_hill_inv_line = v_hill(r_line, K_HILL_INV, ETA_HILL_INV)
v_hill_fwd_line = v_hill(r_line, K_HILL_FWD, ETA_HILL_FWD)

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))

# Panel A: Exponential
ax1.scatter(r_cloud, v_cloud, s=10, c='k', alpha=0.1, label='Rectified Data')
ax1.plot(r_line, v_exp_inv_line, 'b-', lw=3, label=f'Inverse Fit (K={K_EXP_INV:.0f})')
ax1.plot(r_line, v_exp_fwd_line, 'c--', lw=3, label=f'Forward Fit (K={K_EXP_FWD:.0f})')
ax1.set_title('Panel A: Exponential Model Robustness')
ax1.set_ylabel('Escape Velocity $v_0$ (km/s)')
ax1.legend()
ax1.grid(alpha=0.3)

# Panel B: Sigmoidal
# Raw f-strings keep the mathtext "\eta" literal; in a non-raw string "\e" is
# an invalid escape sequence that newer Python versions warn about.
ax2.scatter(r_cloud, v_cloud, s=10, c='k', alpha=0.1, label='Rectified Data')
ax2.plot(
    r_line, v_hill_inv_line, 'r-', lw=3,
    label=rf'Inverse Fit (K={K_HILL_INV:.0f}, $\eta$={ETA_HILL_INV:.2f})',
)
ax2.plot(
    r_line, v_hill_fwd_line, color='orange', linestyle='--', lw=3,
    label=rf'Forward Fit (K={K_HILL_FWD:.0f}, $\eta$={ETA_HILL_FWD:.2f})',
)
ax2.set_title('Panel B: Sigmoidal (Hill) Model Robustness')
ax2.set_xlabel('True Coordinate Distance (Mpc)')
ax2.set_ylabel('Escape Velocity $v_0$ (km/s)')
ax2.legend()
ax2.grid(alpha=0.3)

plt.tight_layout()
plt.savefig(out_path, dpi=300)
print(f"Saved {out_path}")